001package org.granite.gravity.tomcat; 002 003import java.io.IOException; 004import java.security.MessageDigest; 005import java.security.NoSuchAlgorithmException; 006import java.util.ArrayList; 007import java.util.Collections; 008import java.util.Enumeration; 009import java.util.List; 010import java.util.Queue; 011import java.util.concurrent.ConcurrentLinkedQueue; 012 013import javax.servlet.ServletConfig; 014import javax.servlet.ServletException; 015import javax.servlet.ServletRequest; 016import javax.servlet.ServletRequestWrapper; 017import javax.servlet.http.HttpServletRequest; 018import javax.servlet.http.HttpServletResponse; 019import javax.servlet.http.HttpSession; 020 021import org.apache.catalina.connector.RequestFacade; 022import org.apache.catalina.util.Base64; 023import org.apache.catalina.websocket.Constants; 024import org.apache.catalina.websocket.StreamInbound; 025import org.apache.catalina.websocket.WebSocketServlet; 026import org.apache.tomcat.util.buf.B2CConverter; 027import org.apache.tomcat.util.res.StringManager; 028import org.granite.context.GraniteContext; 029import org.granite.gravity.Gravity; 030import org.granite.gravity.GravityManager; 031import org.granite.gravity.GravityServletUtil; 032import org.granite.logging.Logger; 033import org.granite.messaging.webapp.ServletGraniteContext; 034 035import flex.messaging.messages.CommandMessage; 036import flex.messaging.messages.Message; 037 038 039public class TomcatWebSocketServlet extends WebSocketServlet { 040 041 private static final long serialVersionUID = 1L; 042 043 private static final Logger log = Logger.getLogger(TomcatWebSocketServlet.class); 044 045 private static final byte[] WS_ACCEPT = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(B2CConverter.ISO_8859_1); 046 private static final StringManager sm = StringManager.getManager(Constants.Package); 047 048 private final Queue<MessageDigest> sha1Helpers = new ConcurrentLinkedQueue<MessageDigest>(); 049 050 @Override 051 public void init(ServletConfig config) throws ServletException { 052 super.init(config); 053 054 GravityServletUtil.init(config); 055 } 056 057 @Override 058 protected String selectSubProtocol(List<String> subProtocols) { 059 return subProtocols != null && subProtocols.contains("org.granite.gravity") ? "org.granite.gravity" : null; 060 } 061 062 @Override 063 protected StreamInbound createWebSocketInbound(String protocol, HttpServletRequest request) { 064 Gravity gravity = GravityManager.getGravity(getServletContext()); 065 TomcatWebSocketChannelFactory channelFactory = new TomcatWebSocketChannelFactory(gravity, getServletContext()); 066 067 try { 068 String connectMessageId = request.getHeader("connectId") != null ? request.getHeader("connectId") : request.getParameter("connectId"); 069 String clientId = request.getHeader("GDSClientId") != null ? request.getHeader("GDSClientId") : request.getParameter("GDSClientId"); 070 String clientType = request.getHeader("GDSClientType") != null ? request.getHeader("GDSClientType") : request.getParameter("GDSClientType"); 071 String sessionId = null; 072 HttpSession session = request.getSession(false); 073 if (session != null) { 074 ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), 075 getServletContext(), session, clientType); 076 077 sessionId = session.getId(); 078 } 079 else { 080 for (int i = 0; i < request.getCookies().length; i++) { 081 if ("JSESSIONID".equals(request.getCookies()[i].getName())) { 082 sessionId = request.getCookies()[i].getValue(); 083 break; 084 } 085 } 086 087 ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), 088 getServletContext(), sessionId, clientType); 089 } 090 091 log.info("WebSocket connection started %s clientId %s sessionId %s", protocol, clientId, sessionId); 092 093 CommandMessage pingMessage = new CommandMessage(); 094 pingMessage.setMessageId(connectMessageId != null ? connectMessageId : "OPEN_CONNECTION"); 095 pingMessage.setOperation(CommandMessage.CLIENT_PING_OPERATION); 096 if (clientId != null) 097 pingMessage.setClientId(clientId); 098 099 Message ackMessage = gravity.handleMessage(channelFactory, pingMessage); 100 101 TomcatWebSocketChannel channel = gravity.getChannel(channelFactory, (String)ackMessage.getClientId()); 102 103 if (!ackMessage.getClientId().equals(clientId)) 104 channel.setConnectAckMessage(ackMessage); 105 106 return channel.getStreamInbound(); 107 } 108 finally { 109 GraniteContext.release(); 110 } 111 } 112 113 @Override 114 protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { 115 116 // Information required to send the server handshake message 117 String key; 118 String subProtocol = null; 119 List<String> extensions = Collections.emptyList(); 120 121 if (!headerContainsToken(req, "upgrade", "websocket")) { 122 resp.sendError(HttpServletResponse.SC_BAD_REQUEST); 123 return; 124 } 125 126 if (!headerContainsToken(req, "connection", "upgrade")) { 127 resp.sendError(HttpServletResponse.SC_BAD_REQUEST); 128 return; 129 } 130 131 if (!headerContainsToken(req, "sec-websocket-version", "13")) { 132 resp.setStatus(426); 133 resp.setHeader("Sec-WebSocket-Version", "13"); 134 return; 135 } 136 137 key = req.getHeader("Sec-WebSocket-Key"); 138 if (key == null) { 139 resp.sendError(HttpServletResponse.SC_BAD_REQUEST); 140 return; 141 } 142 143 String origin = req.getHeader("Origin"); 144 if (!verifyOrigin(origin)) { 145 resp.sendError(HttpServletResponse.SC_FORBIDDEN); 146 return; 147 } 148 149 // Fix for Tomcat-7.0.29 bad header name (was Sec-WebSocket-Protocol-Client") 150 List<String> subProtocols = getTokensFromHeader(req, "Sec-WebSocket-Protocol"); 151 if (!subProtocols.isEmpty()) 152 subProtocol = selectSubProtocol(subProtocols); 153 154 // TODO Read client handshake - Sec-WebSocket-Extensions 155 156 // TODO Extensions require the ability to specify something (API TBD) 157 // that can be passed to the Tomcat internals and process extension 158 // data present when the frame is fragmented. 159 160 // If we got this far, all is good. Accept the connection. 161 resp.setHeader("Upgrade", "websocket"); 162 resp.setHeader("Connection", "upgrade"); 163 resp.setHeader("Sec-WebSocket-Accept", getWebSocketAccept(key)); 164 if (subProtocol != null) 165 resp.setHeader("Sec-WebSocket-Protocol", subProtocol); 166 167 if (!extensions.isEmpty()) { 168 // TODO 169 } 170 171 WsHttpServletRequestWrapper wrapper = new WsHttpServletRequestWrapper(req); 172 StreamInbound inbound = createWebSocketInbound(subProtocol, wrapper); 173 wrapper.invalidate(); 174 175 // Hack to avoid chunked transfer 176 resp.setContentLength(((TomcatWebSocketChannel.MessageInboundImpl)inbound).getAckLength()); 177 178 // Small hack until the Servlet API provides a way to do this. 179 ServletRequest inner = req; 180 // Unwrap the request 181 while (inner instanceof ServletRequestWrapper) 182 inner = ((ServletRequestWrapper)inner).getRequest(); 183 184 if (inner instanceof RequestFacade) 185 ((RequestFacade)inner).doUpgrade(inbound); 186 else 187 resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, sm.getString("servlet.reqUpgradeFail")); 188 } 189 190 191 private boolean headerContainsToken(HttpServletRequest req, 192 String headerName, String target) { 193 Enumeration<String> headers = req.getHeaders(headerName); 194 while (headers.hasMoreElements()) { 195 String header = headers.nextElement(); 196 String[] tokens = header.split(","); 197 for (String token : tokens) { 198 if (target.equalsIgnoreCase(token.trim())) { 199 return true; 200 } 201 } 202 } 203 return false; 204 } 205 206 private List<String> getTokensFromHeader(HttpServletRequest req, 207 String headerName) { 208 List<String> result = new ArrayList<String>(); 209 210 Enumeration<String> headers = req.getHeaders(headerName); 211 while (headers.hasMoreElements()) { 212 String header = headers.nextElement(); 213 String[] tokens = header.split(","); 214 for (String token : tokens) { 215 result.add(token.trim()); 216 } 217 } 218 return result; 219 } 220 221 private String getWebSocketAccept(String key) throws ServletException { 222 223 MessageDigest sha1Helper = sha1Helpers.poll(); 224 if (sha1Helper == null) { 225 try { 226 sha1Helper = MessageDigest.getInstance("SHA1"); 227 } catch (NoSuchAlgorithmException e) { 228 throw new ServletException(e); 229 } 230 } 231 232 sha1Helper.reset(); 233 sha1Helper.update(key.getBytes(B2CConverter.ISO_8859_1)); 234 String result = Base64.encode(sha1Helper.digest(WS_ACCEPT)); 235 236 sha1Helpers.add(sha1Helper); 237 238 return result; 239 } 240}