001package org.granite.gravity.tomcat;
002
003import java.io.IOException;
004import java.security.MessageDigest;
005import java.security.NoSuchAlgorithmException;
006import java.util.ArrayList;
007import java.util.Collections;
008import java.util.Enumeration;
009import java.util.List;
010import java.util.Queue;
011import java.util.concurrent.ConcurrentLinkedQueue;
012
013import javax.servlet.ServletConfig;
014import javax.servlet.ServletException;
015import javax.servlet.ServletRequest;
016import javax.servlet.ServletRequestWrapper;
017import javax.servlet.http.HttpServletRequest;
018import javax.servlet.http.HttpServletResponse;
019import javax.servlet.http.HttpSession;
020
021import org.apache.catalina.connector.RequestFacade;
022import org.apache.catalina.util.Base64;
023import org.apache.catalina.websocket.Constants;
024import org.apache.catalina.websocket.StreamInbound;
025import org.apache.catalina.websocket.WebSocketServlet;
026import org.apache.tomcat.util.buf.B2CConverter;
027import org.apache.tomcat.util.res.StringManager;
028import org.granite.context.GraniteContext;
029import org.granite.gravity.Gravity;
030import org.granite.gravity.GravityManager;
031import org.granite.gravity.GravityServletUtil;
032import org.granite.logging.Logger;
033import org.granite.messaging.webapp.ServletGraniteContext;
034
035import flex.messaging.messages.CommandMessage;
036import flex.messaging.messages.Message;
037
038
039public class TomcatWebSocketServlet extends WebSocketServlet {
040        
041        private static final long serialVersionUID = 1L;
042        
043        private static final Logger log = Logger.getLogger(TomcatWebSocketServlet.class);
044        
045    private static final byte[] WS_ACCEPT = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(B2CConverter.ISO_8859_1);
046    private static final StringManager sm = StringManager.getManager(Constants.Package);
047
048    private final Queue<MessageDigest> sha1Helpers = new ConcurrentLinkedQueue<MessageDigest>();
049        
050        @Override
051        public void init(ServletConfig config) throws ServletException {
052                super.init(config);
053                
054                GravityServletUtil.init(config);
055        }
056        
057        @Override
058        protected String selectSubProtocol(List<String> subProtocols) {
059                return subProtocols != null && subProtocols.contains("org.granite.gravity") ? "org.granite.gravity" : null;
060        }
061        
062        @Override
063        protected StreamInbound createWebSocketInbound(String protocol, HttpServletRequest request) {
064                Gravity gravity = GravityManager.getGravity(getServletContext());
065                TomcatWebSocketChannelFactory channelFactory = new TomcatWebSocketChannelFactory(gravity, getServletContext());
066                
067                try {
068                        String connectMessageId = request.getHeader("connectId") != null ? request.getHeader("connectId") : request.getParameter("connectId");
069                        String clientId = request.getHeader("GDSClientId") != null ? request.getHeader("GDSClientId") : request.getParameter("GDSClientId");
070                        String clientType = request.getHeader("GDSClientType") != null ? request.getHeader("GDSClientType") : request.getParameter("GDSClientType");
071                        String sessionId = null;
072                        HttpSession session = request.getSession(false);
073                        if (session != null) {
074                        ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), 
075                                getServletContext(), session, clientType);
076                        
077                                sessionId = session.getId();
078                        }
079                        else {
080                                for (int i = 0; i < request.getCookies().length; i++) {
081                                        if ("JSESSIONID".equals(request.getCookies()[i].getName())) {
082                                                sessionId = request.getCookies()[i].getValue();
083                                                break;
084                                        }
085                                }                               
086                                
087                                ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), 
088                                getServletContext(), sessionId, clientType); 
089                        }
090                        
091                        log.info("WebSocket connection started %s clientId %s sessionId %s", protocol, clientId, sessionId);
092                        
093                        CommandMessage pingMessage = new CommandMessage();
094                        pingMessage.setMessageId(connectMessageId != null ? connectMessageId : "OPEN_CONNECTION");
095                        pingMessage.setOperation(CommandMessage.CLIENT_PING_OPERATION);
096                        if (clientId != null)
097                                pingMessage.setClientId(clientId);
098                        
099                        Message ackMessage = gravity.handleMessage(channelFactory, pingMessage);
100                        
101                        TomcatWebSocketChannel channel = gravity.getChannel(channelFactory, (String)ackMessage.getClientId());
102                        
103                        if (!ackMessage.getClientId().equals(clientId))
104                                channel.setConnectAckMessage(ackMessage);
105                        
106                        return channel.getStreamInbound();
107                }
108                finally {
109                        GraniteContext.release();
110                }
111        }
112        
113    @Override
114    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
115
116        // Information required to send the server handshake message
117        String key;
118        String subProtocol = null;
119        List<String> extensions = Collections.emptyList();
120
121        if (!headerContainsToken(req, "upgrade", "websocket")) {
122            resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
123            return;
124        }
125
126        if (!headerContainsToken(req, "connection", "upgrade")) {
127            resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
128            return;
129        }
130
131        if (!headerContainsToken(req, "sec-websocket-version", "13")) {
132            resp.setStatus(426);
133            resp.setHeader("Sec-WebSocket-Version", "13");
134            return;
135        }
136
137        key = req.getHeader("Sec-WebSocket-Key");
138        if (key == null) {
139            resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
140            return;
141        }
142
143        String origin = req.getHeader("Origin");
144        if (!verifyOrigin(origin)) {
145            resp.sendError(HttpServletResponse.SC_FORBIDDEN);
146            return;
147        }
148        
149        // Fix for Tomcat-7.0.29 bad header name (was Sec-WebSocket-Protocol-Client")
150        List<String> subProtocols = getTokensFromHeader(req, "Sec-WebSocket-Protocol");
151        if (!subProtocols.isEmpty())
152            subProtocol = selectSubProtocol(subProtocols);
153
154        // TODO Read client handshake - Sec-WebSocket-Extensions
155
156        // TODO Extensions require the ability to specify something (API TBD)
157        //      that can be passed to the Tomcat internals and process extension
158        //      data present when the frame is fragmented.
159
160        // If we got this far, all is good. Accept the connection.
161        resp.setHeader("Upgrade", "websocket");
162        resp.setHeader("Connection", "upgrade");
163        resp.setHeader("Sec-WebSocket-Accept", getWebSocketAccept(key));
164        if (subProtocol != null)
165            resp.setHeader("Sec-WebSocket-Protocol", subProtocol);
166        
167        if (!extensions.isEmpty()) {
168            // TODO
169        }
170
171        WsHttpServletRequestWrapper wrapper = new WsHttpServletRequestWrapper(req);
172        StreamInbound inbound = createWebSocketInbound(subProtocol, wrapper);
173        wrapper.invalidate();
174        
175        // Hack to avoid chunked transfer
176        resp.setContentLength(((TomcatWebSocketChannel.MessageInboundImpl)inbound).getAckLength());
177        
178        // Small hack until the Servlet API provides a way to do this.
179        ServletRequest inner = req;
180        // Unwrap the request
181        while (inner instanceof ServletRequestWrapper)
182            inner = ((ServletRequestWrapper)inner).getRequest();
183        
184        if (inner instanceof RequestFacade)
185            ((RequestFacade)inner).doUpgrade(inbound);
186        else
187            resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, sm.getString("servlet.reqUpgradeFail"));
188    }
189    
190    
191    private boolean headerContainsToken(HttpServletRequest req,
192            String headerName, String target) {
193        Enumeration<String> headers = req.getHeaders(headerName);
194        while (headers.hasMoreElements()) {
195            String header = headers.nextElement();
196            String[] tokens = header.split(",");
197            for (String token : tokens) {
198                if (target.equalsIgnoreCase(token.trim())) {
199                    return true;
200                }
201            }
202        }
203        return false;
204    }
205    
206    private List<String> getTokensFromHeader(HttpServletRequest req,
207            String headerName) {
208        List<String> result = new ArrayList<String>();
209
210        Enumeration<String> headers = req.getHeaders(headerName);
211        while (headers.hasMoreElements()) {
212            String header = headers.nextElement();
213            String[] tokens = header.split(",");
214            for (String token : tokens) {
215                result.add(token.trim());
216            }
217        }
218        return result;
219    }
220    
221    private String getWebSocketAccept(String key) throws ServletException {
222
223        MessageDigest sha1Helper = sha1Helpers.poll();
224        if (sha1Helper == null) {
225            try {
226                sha1Helper = MessageDigest.getInstance("SHA1");
227            } catch (NoSuchAlgorithmException e) {
228                throw new ServletException(e);
229            }
230        }
231
232        sha1Helper.reset();
233        sha1Helper.update(key.getBytes(B2CConverter.ISO_8859_1));
234        String result = Base64.encode(sha1Helper.digest(WS_ACCEPT));
235
236        sha1Helpers.add(sha1Helper);
237
238        return result;
239    }
240}