001package org.granite.gravity.tomcat;
002
003import java.io.ByteArrayInputStream;
004import java.io.ByteArrayOutputStream;
005import java.io.IOException;
006import java.io.ObjectInput;
007import java.io.ObjectOutput;
008import java.nio.ByteBuffer;
009import java.nio.CharBuffer;
010import java.util.Arrays;
011import java.util.LinkedList;
012
013import javax.servlet.ServletContext;
014
015import org.apache.catalina.websocket.MessageInbound;
016import org.apache.catalina.websocket.StreamInbound;
017import org.apache.catalina.websocket.WsOutbound;
018import org.granite.context.GraniteContext;
019import org.granite.gravity.AbstractChannel;
020import org.granite.gravity.AsyncHttpContext;
021import org.granite.gravity.Gravity;
022import org.granite.gravity.GravityConfig;
023import org.granite.logging.Logger;
024import org.granite.messaging.webapp.ServletGraniteContext;
025
026import flex.messaging.messages.AsyncMessage;
027import flex.messaging.messages.Message;
028
029
030public class TomcatWebSocketChannel extends AbstractChannel {
031        
032        private static final Logger log = Logger.getLogger(TomcatWebSocketChannel.class);
033        
034        private StreamInbound streamInbound = new MessageInboundImpl();
035        private ServletContext servletContext;
036        private WsOutbound connection;
037        private byte[] connectAckMessage;
038
039        
040        public TomcatWebSocketChannel(Gravity gravity, String id, TomcatWebSocketChannelFactory factory, ServletContext servletContext, String clientType) {
041        super(gravity, id, factory, clientType);
042        this.servletContext = servletContext;
043    }
044        
045        public void setConnectAckMessage(Message ackMessage) {
046                try {
047                        // Return an acknowledge message with the server-generated clientId
048                        connectAckMessage = serialize(getGravity(), new Message[] { ackMessage });                      
049                }
050                catch (IOException e) {
051                        throw new RuntimeException("Could not send connect acknowledge", e);
052                }
053        }
054        
055        public StreamInbound getStreamInbound() {
056                return streamInbound;
057        }
058        
059        public class MessageInboundImpl extends MessageInbound {
060                
061                public MessageInboundImpl() {
062                }
063
064                @Override
065                protected void onOpen(WsOutbound outbound) {                    
066                        connection = outbound;
067                        
068                        log.debug("WebSocket connection onOpen");
069                        
070                        if (connectAckMessage == null)
071                                return;
072                        
073                        try {
074                        ByteBuffer buf = ByteBuffer.wrap(connectAckMessage);
075                                connection.writeBinaryMessage(buf);
076                        }
077                        catch (IOException e) {
078                                throw new RuntimeException("Could not send connect acknowledge", e);
079                        }
080                        
081                        connectAckMessage = null;               
082                }
083
084                @Override
085                public void onClose(int closeCode) {
086                        log.debug("WebSocket connection onClose %d", closeCode);
087                        
088                        connection = null;
089                }
090                
091                @Override
092                public void onBinaryMessage(ByteBuffer buf) {
093                        byte[] data = buf.array();
094                        
095                        log.debug("WebSocket connection onBinaryMessage %d", data.length);
096                        
097                        try {
098                                initializeRequest();
099                                
100                                Message[] messages = deserialize(getGravity(), data);
101
102                    log.debug(">> [AMF3 REQUESTS] %s", (Object)messages);
103
104                    Message[] responses = null;
105                    
106                    boolean accessed = false;
107                    int responseIndex = 0;
108                    for (int i = 0; i < messages.length; i++) {
109                        Message message = messages[i];
110                        
111                        // Ask gravity to create a specific response (will be null with a connect request from tunnel).
112                        Message response = getGravity().handleMessage(getFactory(), message);
113                        String channelId = (String)message.getClientId();
114                        
115                        // Mark current channel (if any) as accessed.
116                        if (!accessed)
117                                accessed = getGravity().access(channelId);
118                        
119                        if (response != null) {
120                                if (responses == null)
121                                        responses = new Message[1];
122                                else
123                                        responses = Arrays.copyOf(responses, responses.length+1);
124                                responses[responseIndex++] = response;
125                        }
126                    }
127                    
128                    if (responses != null && responses.length > 0) {
129                            log.debug("<< [AMF3 RESPONSES] %s", (Object)responses);
130                
131                            byte[] resultData = serialize(getGravity(), responses);
132                            
133                            connection.writeBinaryMessage(ByteBuffer.wrap(resultData));
134                    }
135                        }
136                        catch (ClassNotFoundException e) {
137                                log.error(e, "Could not handle incoming message data");
138                        }
139                        catch (IOException e) {
140                                log.error(e, "Could not handle incoming message data");
141                        }
142                        finally {
143                                cleanupRequest();
144                        }
145                }
146
147                @Override
148                protected void onTextMessage(CharBuffer buf) throws IOException {
149                }
150                
151                public int getAckLength() {
152                        return connectAckMessage != null ? connectAckMessage.length : 0;
153                }
154        }
155        
156        private Gravity initializeRequest() {
157                ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), servletContext, sessionId, clientType);
158                return gravity;
159        }
160
161        private static Message[] deserialize(Gravity gravity, byte[] data) throws ClassNotFoundException, IOException {
162                ByteArrayInputStream is = new ByteArrayInputStream(data);
163                try {
164                        ObjectInput amf3Deserializer = gravity.getGraniteConfig().newAMF3Deserializer(is);
165                Object[] objects = (Object[])amf3Deserializer.readObject();
166                Message[] messages = new Message[objects.length];
167                System.arraycopy(objects, 0, messages, 0, objects.length);
168                
169                return messages;
170                }
171                finally {
172                        is.close();
173                }
174        }
175        
176        private static byte[] serialize(Gravity gravity, Message[] messages) throws IOException {
177                ByteArrayOutputStream os = null;
178                try {
179                os = new ByteArrayOutputStream(200*messages.length);
180                ObjectOutput amf3Serializer = gravity.getGraniteConfig().newAMF3Serializer(os);
181                amf3Serializer.writeObject(messages);           
182                os.flush();
183                return os.toByteArray();
184                }
185                finally {
186                        if (os != null)
187                                os.close();
188                }               
189        }
190        
191        private static void cleanupRequest() {
192                GraniteContext.release();
193        }
194        
195        @Override
196        public boolean runReceived(AsyncHttpContext asyncHttpContext) {
197                
198                LinkedList<AsyncMessage> messages = null;
199                ByteArrayOutputStream os = null;
200
201                try {
202                        receivedQueueLock.lock();
203                        try {
204                                // Do we have any pending messages? 
205                                if (receivedQueue.isEmpty())
206                                        return false;
207                                
208                                // Both conditions are ok, get all pending messages.
209                                messages = receivedQueue;
210                                receivedQueue = new LinkedList<AsyncMessage>();
211                        }
212                        finally {
213                                receivedQueueLock.unlock();
214                        }
215                        
216                        if (connection == null)
217                                return false;
218                        
219                        AsyncMessage[] messagesArray = new AsyncMessage[messages.size()];
220                        int i = 0;
221                        for (AsyncMessage message : messages)
222                                messagesArray[i++] = message;
223                        
224                        // Setup serialization context (thread local)
225                        Gravity gravity = getGravity();
226                GraniteContext context = ServletGraniteContext.createThreadInstance(gravity.getGraniteConfig(), gravity.getServicesConfig(), servletContext, sessionId, clientType);
227                
228                os = new ByteArrayOutputStream(500);
229                ObjectOutput amf3Serializer = context.getGraniteConfig().newAMF3Serializer(os);
230                
231                log.debug("<< [MESSAGES for channel=%s] %s", this, messagesArray);
232                
233                amf3Serializer.writeObject(messagesArray);
234                
235                connection.writeBinaryMessage(ByteBuffer.wrap(os.toByteArray()));
236                
237                return true; // Messages were delivered
238                }
239                catch (IOException e) {
240                        log.warn(e, "Could not send messages to channel: %s (retrying later)", this);
241                        
242                        GravityConfig gravityConfig = getGravity().getGravityConfig();
243                        if (gravityConfig.isRetryOnError()) {
244                                receivedQueueLock.lock();
245                                try {
246                                        if (receivedQueue.size() + messages.size() > gravityConfig.getMaxMessagesQueuedPerChannel()) {
247                                                log.warn(
248                                                        "Channel %s has reached its maximum queue capacity %s (throwing %s messages)",
249                                                        this,
250                                                        gravityConfig.getMaxMessagesQueuedPerChannel(),
251                                                        messages.size()
252                                                );
253                                        }
254                                        else
255                                                receivedQueue.addAll(0, messages);
256                                }
257                                finally {
258                                        receivedQueueLock.unlock();
259                                }
260                        }
261                        
262                        return true; // Messages weren't delivered, but http context isn't valid anymore.
263                }
264                finally {
265                        if (os != null) {
266                                try {
267                                        os.close();
268                                }
269                                catch (Exception e) {
270                                        // Could not close bytearray ???
271                                }
272                        }
273                        
274                        // Cleanup serialization context (thread local)
275                        try {
276                                GraniteContext.release();
277                        }
278                        catch (Exception e) {
279                                // should never happen...
280                        }
281                }
282        }
283
284        @Override
285        public void destroy() {
286                try {
287                        super.destroy();
288                }
289                finally {
290                        close();
291                }
292        }
293        
294        public void close() {
295                if (connection != null) {
296                        try {
297                                connection.close(1000, ByteBuffer.wrap("Channel closed".getBytes()));
298                        }
299                        catch (IOException e) {
300                                log.error("Could not close WebSocket connection", e);
301                        }
302                        connection = null;
303                }
304        }
305        
306        @Override
307        protected boolean hasAsyncHttpContext() {
308                return true;
309        }
310
311        @Override
312        protected void releaseAsyncHttpContext(AsyncHttpContext context) {
313        }
314
315        @Override
316        protected AsyncHttpContext acquireAsyncHttpContext() {
317        return null;
318    }           
319}